// Copyright (c) 2022 ChenJun
// Licensed under the MIT License.

// OpenCV
#include <opencv2/core.hpp>
#include <opencv2/core/base.hpp>
#include <opencv2/core/mat.hpp>
#include <opencv2/core/types.hpp>
#include <opencv2/imgproc.hpp>

// STD
#include <cmath>
#include <memory>
#include <vector>

#include "hnurm_detect/armor.hpp"
#include "hnurm_detect/detector.hpp"

// Channel indices used when summing per-pixel colour values inside a light
// contour (see Detector::findLights). Blue at index 0 and red at index 2
// correspond to a BGR channel order.
// NOTE(review): Detector::preprocessImage converts the same input with
// cv::COLOR_RGB2GRAY, which implies an RGB layout - confirm which channel
// order the input image really uses.
constexpr int CHANNEL_BLUE = 0;
constexpr int CHANNEL_RED = 2;

namespace hnurm
{
    /**
     * @brief Construct a detector from its thresholds and classifier settings.
     *
     * @param min_lightness     Gray-level threshold stored for preprocessImage().
     * @param light_params      Limits stored for isLight().
     * @param armor_params      Limits stored for isArmor().
     * @param classifier_params Settings forwarded to the NumberClassifier that is
     *                          created here and owned by this detector.
     */
    Detector::Detector(
        int min_lightness,
        const LightParams &light_params,
        const ArmorParams &armor_params,
        const ClassifierParams &classifier_params
    )
        : min_lightness(min_lightness),
          light_params(light_params),
          armor_params(armor_params)
    {
        // Positional unpacking: the binding order must follow the member order of
        // ClassifierParams (model path, label path, ignored classes, threshold).
        const auto &[model_file, label_file, ignored_classes, min_confidence] = classifier_params;

        // NumberClassifier expects the threshold before the ignore list.
        classifier = std::make_unique<NumberClassifier>(model_file, label_file, min_confidence, ignored_classes);
    }

    /**
     * @brief Run the full detection pipeline on one frame.
     *
     * Stores the binarised frame in binary_img, the light candidates in lights_
     * and the matched armors in armors_. The number classifier only runs when at
     * least one armor was found.
     *
     * @param input      Colour frame to analyse.
     * @param self_color Colour of the own team; lights of this colour are skipped
     *                   during matching and the value is passed to the classifier.
     */
    void Detector::detect(const cv::Mat &input, int self_color)
    {
        binary_img = preprocessImage(input);
        lights_ = findLights(input, binary_img);
        armors_ = matchLights(lights_, self_color);

        // Nothing to classify without armor candidates.
        if (armors_.empty())
        {
            return;
        }

        classifier->extractNumbers(input, armors_);
        classifier->classify(armors_, self_color);
    }

    /**
     * @brief Build one debug image containing the number image of every armor.
     *
     * @return The number images of all armors in armors_ stacked vertically, or a
     *         black 20x28 (width x height) single-channel placeholder when there
     *         are no armors.
     */
    cv::Mat Detector::getAllNumbersImage()
    {
        if (armors_.empty())
        {
            // Zero-fill the placeholder: a cv::Mat built from size and type alone
            // leaves its pixel buffer uninitialised.
            return cv::Mat(cv::Size(20, 28), CV_8UC1, cv::Scalar(0));
        }

        std::vector<cv::Mat> number_imgs;
        number_imgs.reserve(armors_.size());
        for (const auto &armor: armors_)
        {
            // Copying a cv::Mat only copies the header; pixel data is shared.
            number_imgs.emplace_back(armor.number_img);
        }

        cv::Mat all_num_img;
        cv::vconcat(number_imgs, all_num_img);
        return all_num_img;
    }

    /**
     * @brief Draw the latest detection results onto an image for debugging.
     *
     * Draws, in this order: every light (end points and axis), the two diagonals
     * of every armor, and finally the classification text of every armor.
     *
     * @param img Image that is drawn on in place.
     */
    void Detector::drawResults(cv::Mat &img)
    {
        const cv::Scalar endpoint_color(255, 255, 255);
        const cv::Scalar red_light_axis_color(255, 255, 0);
        const cv::Scalar other_light_axis_color(255, 0, 255);
        const cv::Scalar armor_color(0, 255, 0);
        const cv::Scalar text_color(0, 255, 255);

        // Lights: a small circle on each end point plus the connecting axis.
        for (const auto &bar: lights_)
        {
            cv::circle(img, bar.top, 3, endpoint_color, 1);
            cv::circle(img, bar.bottom, 3, endpoint_color, 1);

            const cv::Scalar &axis_color = (bar.color == RED) ? red_light_axis_color : other_light_axis_color;
            cv::line(img, bar.top, bar.bottom, axis_color, 1);
        }

        // Armors: both diagonals between the end points of the two lights.
        for (const auto &plate: armors_)
        {
            cv::line(img, plate.left_light.top, plate.right_light.bottom, armor_color, 2);
            cv::line(img, plate.left_light.bottom, plate.right_light.top, armor_color, 2);
        }

        // Classification text, anchored at the top of the left light. Kept as a
        // separate pass so the text is always drawn after all armor diagonals.
        for (const auto &plate: armors_)
        {
            cv::putText(img, plate.classification_result, plate.left_light.top,
                        cv::FONT_HERSHEY_SIMPLEX, 0.8, text_color, 2);
        }
    }

    /**
     * @brief Turn a colour frame into a cleaned-up binary mask of bright regions.
     *
     * The frame is converted to gray, thresholded at min_lightness
     * (cv::THRESH_BINARY, maximum value 255) and then opened morphologically
     * (one erosion followed by one dilation) to remove small specks.
     *
     * @param rgb_img Colour frame; converted with cv::COLOR_RGB2GRAY.
     * @return Single-channel binary image.
     */
    cv::Mat Detector::preprocessImage(const cv::Mat &rgb_img) const
    {
        cv::Mat gray;
        cv::cvtColor(rgb_img, gray, cv::COLOR_RGB2GRAY);

        cv::Mat mask;
        cv::threshold(gray, mask, min_lightness, 255, cv::THRESH_BINARY);

        // Opening = erosion followed by dilation with the same structuring element
        // (a 2x2 elliptical kernel).
        const cv::Mat element = cv::getStructuringElement(cv::MORPH_ELLIPSE, cv::Size(2, 2));
        const cv::Point default_anchor(-1, -1);
        const int passes = 1;

        // Erode first to wipe out small bright specks ...
        cv::erode(mask, mask, element, default_anchor, passes);
        // ... then dilate to restore the shape of what survived.
        cv::dilate(mask, mask, element, default_anchor, passes);

        return mask;
    }
    
    /**
     * @brief Compute the arithmetic mean (centroid) of a set of integer points.
     *
     * @param points Points to average. The sums are divided by points.size(), so
     *               an empty vector does not produce a meaningful result.
     * @return Mean x and y coordinates as doubles.
     */
    cv::Point2d Detector::calculateMean(const std::vector<cv::Point> &points)
    {
        double sum_x = 0.0;
        double sum_y = 0.0;
        for (const cv::Point &pt: points)
        {
            sum_x += pt.x;
            sum_y += pt.y;
        }

        const auto count = points.size();
        return cv::Point2d(sum_x / count, sum_y / count);
    }

    /**
     * @brief Shift a set of points so that the given mean becomes the origin.
     *
     * @param points Input points.
     * @param mean   Point subtracted from every input point.
     * @return One double-precision point per input point, in the same order.
     */
    std::vector<cv::Point2d> Detector::centralizeData(const std::vector<cv::Point> &points, const cv::Point2d &mean)
    {
        std::vector<cv::Point2d> centralized;
        // The output size is known up front; reserve once instead of regrowing.
        centralized.reserve(points.size());
        for (const auto &point: points)
        {
            centralized.emplace_back(point.x - mean.x, point.y - mean.y);
        }
        return centralized;
    }

// Compute the 2x2 covariance matrix of a set of already-centred points.
// The accumulated products are divided by the number of points, so the input
// is expected to be non-empty.
    Eigen::Matrix2d Detector::calculateCovarianceMatrix(const std::vector<cv::Point2d> &centralized)
    {
        double sum_xx = 0.0;
        double sum_xy = 0.0;
        double sum_yy = 0.0;
        for (const cv::Point2d &pt: centralized)
        {
            sum_xx += pt.x * pt.x;
            sum_xy += pt.x * pt.y;
            sum_yy += pt.y * pt.y;
        }

        // The matrix is symmetric: both off-diagonal entries hold the x*y sum.
        Eigen::Matrix2d covarianceMatrix;
        covarianceMatrix << sum_xx, sum_xy,
                            sum_xy, sum_yy;
        covarianceMatrix /= centralized.size();
        return covarianceMatrix;
    }
    /**
     * @brief Find the principal axis of a point set and return a second point on it.
     *
     * The points are centred on @p mean, their covariance matrix is decomposed
     * and the eigenvector stored in column 1 is taken as the principal direction.
     * Eigen's SelfAdjointEigenSolver sorts eigenvalues in increasing order, so
     * column 1 belongs to the largest eigenvalue.
     *
     * @param points Points to analyse.
     * @param mean   Mean of @p points (see calculateMean()).
     * @return @p mean displaced by the principal eigenvector; together with
     *         @p mean it defines the principal axis as a line.
     */
    cv::Point2d Detector::performPCA(const std::vector<cv::Point> &points, const cv::Point2d &mean)
    {
        const Eigen::Matrix2d covariance = calculateCovarianceMatrix(centralizeData(points, mean));

        // Eigen-decomposition of the symmetric covariance matrix.
        const Eigen::SelfAdjointEigenSolver<Eigen::Matrix2d> eigen_solver(covariance);
        const Eigen::Vector2d principal_axis = eigen_solver.eigenvectors().col(1);

        return cv::Point2d(mean.x + principal_axis(0), mean.y + principal_axis(1));
    }
    /**
     * @brief Intersect the infinite line through p1,p2 with the one through p3,p4.
     *
     * Each line is written as A*x + B*y = C and the 2x2 system is solved with
     * Cramer's rule.
     *
     * @param p1,p2        Two points on the first line.
     * @param p3,p4        Two points on the second line.
     * @param intersection Receives the intersection point; left untouched when
     *                     the function returns false.
     * @return false when the lines are parallel, coincident or degenerate (a
     *         line given by two identical points), true otherwise.
     */
    bool Detector::computeLineIntersection(const cv::Point2d &p1,
                                           const cv::Point2d &p2,
                                           const cv::Point2d &p3,
                                           const cv::Point2d &p4,
                                           cv::Point2d &intersection)
    {
        const double A1 = p2.y - p1.y;
        const double B1 = p1.x - p2.x;
        const double C1 = A1 * p1.x + B1 * p1.y;

        const double A2 = p4.y - p3.y;
        const double B2 = p3.x - p4.x;
        const double C2 = A2 * p3.x + B2 * p3.y;

        // det = |d1| * |d2| * sin(angle between the lines). Comparing it with
        // exactly zero misses nearly parallel lines, whose "intersection" would be
        // a numerically meaningless far-away point, so test it relative to the
        // lengths of both direction vectors instead.
        constexpr double kParallelSinTolerance = 1e-9;
        const double det = A1 * B2 - A2 * B1;
        const double direction_scale = std::sqrt((A1 * A1 + B1 * B1) * (A2 * A2 + B2 * B2));
        if (std::abs(det) <= kParallelSinTolerance * direction_scale)
        {
            return false; // parallel, coincident or degenerate lines
        }

        intersection.x = (B2 * C1 - B1 * C2) / det;
        intersection.y = (A1 * C2 - A2 * C1) / det;
        return true;
    }
    /**
     * @brief Extract light-bar candidates from the binary mask and colour them.
     *
     * For every external contour with at least 5 points:
     *  1. fit a minimum-area rectangle and build a Light from it;
     *  2. discard it unless isLight() accepts it and its bounding rectangle lies
     *     completely inside the colour image;
     *  3. sum the CHANNEL_RED and CHANNEL_BLUE values of the pixels inside the
     *     contour and label the light RED or BLUE by the larger sum;
     *  4. refine top, bottom, center and tilt_angle using the principal axis of
     *     the contour points;
     *  5. store the finished light.
     *
     * @param rbg_img     Colour frame (3 channels, 8 bit) the mask was derived from.
     * @param _binary_img Binary mask produced by preprocessImage().
     * @return All accepted lights, in contour order.
     */
    std::vector<Light> Detector::findLights(const cv::Mat &rbg_img, const cv::Mat &_binary_img)
    {
        using std::vector;
        vector<vector<cv::Point>> contours;
        cv::findContours(_binary_img, contours, cv::RETR_EXTERNAL, cv::CHAIN_APPROX_NONE);

        vector<Light> lights;

        for (const auto &contour: contours)
        {
            if (contour.size() < 5)
                continue;

            auto light = Light(cv::minAreaRect(contour));
            if (!isLight(light))
                continue;

            // Taking a ROI outside the image triggers an OpenCV assertion, so
            // only lights that lie completely inside the frame are kept.
            const auto rect = light.boundingRect();
            const bool inside_image =
                0 <= rect.x && 0 <= rect.width && rect.x + rect.width <= rbg_img.cols && 0 <= rect.y &&
                0 <= rect.height && rect.y + rect.height <= rbg_img.rows;
            if (!inside_image)
                continue;

            // Colour: accumulate the red and blue channel of every ROI pixel that
            // lies inside (or on) the contour.
            int sum_r = 0;
            int sum_b = 0;
            const cv::Mat roi = rbg_img(rect);
            for (int row = 0; row < roi.rows; row++)
            {
                for (int col = 0; col < roi.cols; col++)
                {
                    if (cv::pointPolygonTest(contour, cv::Point2f(col + rect.x, row + rect.y), false) >= 0)
                    {
                        const auto &pixel = roi.at<cv::Vec3b>(row, col);
                        sum_r += pixel[CHANNEL_RED];
                        sum_b += pixel[CHANNEL_BLUE];
                    }
                }
            }
            light.color = sum_r > sum_b ? RED : BLUE;

            // Geometry refinement: intersect the principal axis of the contour
            // (the line through mean and p1) with two opposite rectangle edges.
            // NOTE(review): assumes light.p[0]/p[1] span the top edge and
            // light.p[2]/p[3] the bottom edge - confirm against Light's constructor.
            const cv::Point2d mean = calculateMean(contour);
            const cv::Point2d p1 = performPCA(contour, mean);

            // Only overwrite an end point when the intersection exists; otherwise
            // keep the rectangle-based value computed by Light's constructor.
            cv::Point2d intersection;
            if (computeLineIntersection(mean, p1, light.p[0], light.p[1], intersection))
            {
                light.top = intersection;
            }
            if (computeLineIntersection(mean, p1, light.p[2], light.p[3], intersection))
            {
                light.bottom = intersection;
            }
            light.center = mean;
            // Angle between the light axis and the vertical, in degrees.
            light.tilt_angle =
                std::atan2(std::abs(light.top.x - light.bottom.x), std::abs(light.top.y - light.bottom.y));
            light.tilt_angle = light.tilt_angle / CV_PI * 180;

            // Store the light only after every field is final: the vector keeps a
            // copy, so later changes to the local object would be lost.
            lights.emplace_back(light);
        }
        return lights;
    }

    /**
     * @brief Decide whether a fitted rectangle looks like an armor light bar.
     *
     * Two size classes are accepted, both subject to
     * tilt_angle < light_params.max_angle:
     *  - area > 35: the width/length ratio must lie strictly between
     *    light_params.min_ratio and light_params.max_ratio;
     *  - 15 < area <= 35: the ratio must lie strictly between 0.1 and 0.4.
     * Anything with an area of 15 or less is rejected.
     *
     * @param light Candidate to test.
     * @return true when the candidate passes the checks of its size class.
     */
    bool Detector::isLight(const Light &light) const
    {
        constexpr double kLargeLightMinArea = 35.0;  // was 40 at some point
        constexpr double kSmallLightMinArea = 15.0;
        constexpr double kSmallLightMinRatio = 0.1;
        constexpr double kSmallLightMaxRatio = 0.4;

        // The ratio of light (short side / long side)
        const auto ratio = static_cast<float>((light.width / light.length));
        const bool angle_ok = light.tilt_angle < light_params.max_angle;
        const auto area = light.width * light.length;

        if (area > kLargeLightMinArea)
        {
            const bool ratio_ok = light_params.min_ratio < ratio && ratio < light_params.max_ratio;
            return ratio_ok && angle_ok;
        }

        // Small lights use their own fixed ratio window. Reaching this point
        // means area <= 35, so an area of exactly 35 belongs to this class
        // instead of falling between the two.
        const bool small_ratio_ok = kSmallLightMinRatio < ratio && ratio < kSmallLightMaxRatio;
        const bool small_size_ok = area > kSmallLightMinArea;
        return small_ratio_ok && angle_ok && small_size_ok;
    }

    std::vector<Armor> Detector::matchLights(const std::vector<Light> &lights, int self_color)
    {
        std::vector<Armor> armors;

        // Loop all the pairing of lights
        for (int i = 0; i < lights.size(); i++)
        {
            if (lights[i].color == self_color)
                continue;
            for (int j = i + 1; j < lights.size(); j++)
            {
                if (lights[j].color == self_color)
                    continue;

                if (containLight(lights[i], lights[j], lights))
                {
                    continue;
                }
                auto armor = Armor(lights[i], lights[j]);
                if (isArmor(armor))
                {
                    armors.emplace_back(armor);
                }
            }
        }

        return armors;
    }

// Check if there is another light in the boundingRect formed by the 2 lights
    bool Detector::containLight(const Light &light_1, const Light &light_2, const std::vector<Light> &lights)
    {
        auto points = std::vector<cv::Point2f>{light_1.top, light_1.bottom, light_2.top, light_2.bottom};
        auto bounding_rect = cv::boundingRect(points);

        for (const auto &test_light: lights)
        {
            if (test_light.center == light_1.center || test_light.center == light_2.center)
                continue;

            if (bounding_rect.contains(test_light.top) || bounding_rect.contains(test_light.bottom)
                || bounding_rect.contains(test_light.center))
            {
                return true;
            }
        }

        return false;
    }

    /**
     * @brief Check whether two paired lights form a plausible armor plate.
     *
     * All of the following must hold:
     *  - the shorter/longer light length ratio exceeds armor_params.min_light_ratio;
     *  - the center distance (in units of the average light length) falls inside
     *    the small-armor or the large-armor distance window;
     *  - the line connecting both centers is tilted by less than
     *    armor_params.max_angle degrees from the horizontal;
     *  - the tilt angles of both lights differ by less than 24 degrees.
     *
     * @param armor Candidate to test. Its armor_type is always updated (LARGE when
     *              the center distance exceeds min_large_center_distance, SMALL
     *              otherwise), even when the candidate is rejected.
     * @return true when the candidate passes every check.
     */
    bool Detector::isArmor(Armor &armor) const
    {
        constexpr double kMaxTiltAngleDiff = 24.0;  // degrees

        // References instead of copies: the lights are only read here.
        const Light &light_1 = armor.left_light;
        const Light &light_2 = armor.right_light;

        // Ratio of the length of 2 lights (short side / long side)
        auto light_length_ratio
            = light_1.length < light_2.length ? light_1.length / light_2.length : light_2.length / light_1.length;
        bool light_ratio_ok = light_length_ratio > armor_params.min_light_ratio;

        // Distance between the center of 2 lights (unit : light length)
        auto avg_light_length = (light_1.length + light_2.length) / 2;
        auto center_distance = cv::norm(light_1.center - light_2.center) / avg_light_length;
        bool center_distance_ok = (armor_params.min_small_center_distance < center_distance
            && center_distance < armor_params.max_small_center_distance)
            || (armor_params.min_large_center_distance < center_distance
                && center_distance < armor_params.max_large_center_distance);

        // Angle of light center connection
        cv::Point2f diff = light_1.center - light_2.center;
        auto angle = std::abs(std::atan(diff.y / diff.x)) / CV_PI * 180;
        bool angle_ok = angle < armor_params.max_angle;

        // std::abs picks the floating-point overload; an unqualified abs may
        // resolve to the integer C function and truncate the difference.
        bool abs_angle = std::abs(light_1.tilt_angle - light_2.tilt_angle) < kMaxTiltAngleDiff;

        bool is_armor = light_ratio_ok && center_distance_ok && angle_ok && abs_angle;
        armor.armor_type = center_distance > armor_params.min_large_center_distance ? LARGE : SMALL;
        return is_armor;
    }

}  // namespace hnurm
